# Emma training module
import json

import torch
from nn.attention import Attention
from nn.encoder import Encoder
from nn.decoder import Decoder
from nn.seq2seq import Seq2Seq
import torch.nn as nn
import torch.optim as optim
import config
from lib.data_center import DataCenter
import lib.util as util
import numpy as np
from torch.utils.data import DataLoader
from lib.dataset import MyDataSet, Seq2seqDataSet

# Prefer GPU when available; the model and loss are moved to this device in train().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# If CUDA causes conflicts, uncomment this line to force CPU-only execution.
# device = "cpu"
print("device = ", device)
# Module-level data hub: load the persisted dictionaries once at import time
# so train()/main() can reuse them. NOTE(review): DataCenter is project-local;
# load_dict/set_dict semantics assumed from usage — confirm against lib.data_center.
dataset = DataCenter()
dicts = dataset.load_dict()
dataset.set_dict(dicts)


def get_all_token(datas):
    """Build a token -> index vocabulary from the dataset.

    Index 0 is reserved for the "EOF" token; every other character is
    assigned the next free index in first-seen order.

    Args:
        datas: iterable of dicts, each with a 'sentence' string field.

    Returns:
        dict mapping each character token to a unique int index.
    """
    token2idx = {"EOF": 0}
    for record in datas:
        for word in record['sentence']:
            # Membership test directly on the dict (O(1)), not on .keys().
            if word not in token2idx:
                # The next free index is exactly the current vocabulary size,
                # so no separate counter is needed.
                token2idx[word] = len(token2idx)
    return token2idx


def save_token(tokens: dict) -> None:
    """Persist the token -> index vocabulary as UTF-8 JSON at config.token_path.

    ensure_ascii=False keeps CJK characters readable in the saved file.
    """
    # `with` guarantees the file handle is closed even if serialization fails,
    # replacing the manual open/try/finally/close dance.
    with open(config.token_path, encoding='utf-8', mode='w') as fs:
        fs.write(json.dumps(tokens, ensure_ascii=False))


def _append_log(entry: dict) -> None:
    """Append one JSON record as a single line to the training log file."""
    with open(config.log_path, encoding='utf-8', mode='a+') as fs:
        fs.write(json.dumps(entry))
        fs.write('\n')


def train(datas):
    """Train the attention seq2seq model on `datas`.

    Args:
        datas: iterable of dicts with a 'sentence' string field
            (the format consumed by get_all_token / Seq2seqDataSet).

    Side effects:
        - Saves model weights to config.seq2seq_model after every epoch
          (and once up-front if no checkpoint could be loaded).
        - Appends per-epoch and final loss records to config.log_path.
    """
    token2idx = get_all_token(datas)
    # Renamed from `dataset` to avoid shadowing the module-level DataCenter.
    train_dataset = Seq2seqDataSet(datas, token2idx, long_limit=config.LONG_LIMIT)

    attn = Attention(config.ENC_HID_DIM, config.DEC_HID_DIM)
    enc = Encoder(len(token2idx), config.ENC_EMB_DIM, config.ENC_HID_DIM, config.DEC_HID_DIM, config.ENC_DROPOUT)
    dec = Decoder(len(token2idx), config.DEC_EMB_DIM, config.ENC_HID_DIM, config.DEC_HID_DIM, config.DEC_DROPOUT,
                  attn)
    model = Seq2Seq(enc, dec, device).to(device)
    try:
        model.load_state_dict(torch.load(config.seq2seq_model))
    except Exception as e:
        # Best-effort resume: if no usable checkpoint exists, report the error
        # and write one from the freshly initialized weights so later loads succeed.
        print("读取模型出错", e)
        torch.save(model.state_dict(), config.seq2seq_model)

    # Loss and optimizer.
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    dataloader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True,
                            num_workers=0,
                            drop_last=True)

    all_loss = []
    for i in range(config.EPOCH):
        epoch_loss = 0
        epoch_loss_list = []
        for batch_idx, (src, target) in enumerate(dataloader):
            # BUG FIX: was src.cuda()/target.cuda(), which crashes on CPU-only
            # machines even though `device` falls back to CPU above.
            src = src.to(device)
            target = target.to(device)

            pred = model(src, target)
            pred_dim = pred.shape[-1]

            # Drop the leading start-of-sequence step and flatten for CrossEntropyLoss:
            # trg = [(trg len - 1) * batch size]
            # pred = [(trg len - 1) * batch size, pred_dim]
            target = target[1:].view(-1)
            pred = pred[1:].view(-1, pred_dim)

            loss = criterion(pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_loss_list.append(float(loss.item()))
            # batch_idx + 1 replaces the redundant hand-maintained `j` counter.
            print(
                "epoch = {},第{}批数据 epoch_loss = {},single_loss = {}".format(i + 1, batch_idx + 1, epoch_loss,
                                                                             loss.item()))

        # Log this epoch's per-batch losses, then checkpoint the model.
        _append_log({
            'type': 'epoch_loss',
            'data': epoch_loss_list
        })
        all_loss.append(epoch_loss)
        torch.save(model.state_dict(), config.seq2seq_model)
        print("保存模型和日志成功")

    # Final summary record: cumulative loss per epoch.
    _append_log({
        'type': 'all_loss',
        'data': all_loss
    })
    print("保存数据成功")


def main():
    """Script entry point: load the 'xiaohuangji' corpus and train on it."""
    datas = dataset.get_data_set()
    fanzxl = datas["xiaohuangji"]
    train(fanzxl)


# BUG FIX: the __name__ guard used to live *inside* main() while main() was
# called unconditionally at import time. The standard top-level guard is
# equivalent (runs only as a script, no-op on import) and idiomatic.
if __name__ == '__main__':
    main()
